from abc import ABC

import tensorflow as tf
from db_config import cfg


class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule, ABC):
    """Polynomial learning-rate decay configured from ``cfg.TRAIN``.

    The learning rate at a given step is::

        init_lr * (1 - min(step / max_steps, 1)) ** power

    so it starts at ``init_lr`` for step 0 and reaches 0 at ``max_steps``.
    Past ``max_steps`` the progress fraction is clamped to 1, which keeps the
    base non-negative; without the clamp a fractional ``power`` yields NaN and
    an even integer ``power`` makes the rate grow again.

    Args:
        init_lr: Initial learning rate. Defaults to ``cfg.TRAIN.LEARNING_RATE``.
        max_steps: Step at which the rate reaches 0. Defaults to
            ``cfg.TRAIN.MAX_STEPS``.
        power: Exponent of the decay polynomial. Defaults to
            ``cfg.TRAIN.POWER``.
    """

    def __init__(self, init_lr=None, max_steps=None, power=None):
        super().__init__()
        # Explicit arguments exist so that ``from_config`` (which calls
        # ``cls(**config)``) can rebuild the schedule; calling with no
        # arguments reads everything from the project config as before.
        self.init_lr = cfg.TRAIN.LEARNING_RATE if init_lr is None else init_lr
        self.max_steps = cfg.TRAIN.MAX_STEPS if max_steps is None else max_steps
        self.power = cfg.TRAIN.POWER if power is None else power

    def __call__(self, step):
        """Return the learning rate for ``step``.

        Args:
            step: Current training step (a Python number or a tensor).

        Returns:
            A tensor holding the decayed learning rate.
        """
        # Clamp so the base of the power never goes below zero once training
        # runs past ``max_steps``.
        progress = tf.minimum(step / self.max_steps, 1.0)
        return self.init_lr * ((1 - progress) ** self.power)

    def get_config(self):
        """Return the constructor arguments needed to recreate this schedule."""
        return {
            "init_lr": self.init_lr,
            "max_steps": self.max_steps,
            "power": self.power,
        }